from typing import TypedDict, Dict, List, Annotated, Union, Optional
from langgraph.constants import START, END
from langgraph.graph import add_messages, StateGraph
from core.testAgent.EnvironmentService import EnvironmentService
import json
import operator
from langgraph.types import Send
from langgraph.pregel import RetryPolicy
# from ExecuteTools import compile_and_execute_test, simple_fix
from core.testAgent.RetrieverTools import (find_class, find_method_definition, find_variable_definition, find_method_calls,
                            find_method_usages, fuzzy_search, graph_retriever)
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from core.testAgent.LoggerManager import LoggerManager
import re
import os
from core.testAgent.Config import *
from core.testAgent.PromptTemplate import Method_Analyzer_Template, Test_Points_Init_Template, Planner_Init_State_Template, \
    Acceptance_Review_Template, Final_Test_Case_Fix_Template, Test_Points_Review_Template
from core.testAgent.Generator import generator_graph
import concurrent.futures


class PlannerState(TypedDict):
    """State carried through the top-level planner graph.

    Populated incrementally by the nodes below: `planner_init_state` seeds
    `messages`/`test_cases`, `methodAnalyzer` fills `method_summary`,
    `testPointsGenerator` fills `test_points`, and fan-out generator runs
    accumulate into `test_cases` (reducer: operator.add).
    """
    envServer: EnvironmentService  # environment service handle
    package_name: str  # package name of the code under test
    method_id: int  # ID of the method under test
    job_id: Optional[str]  # progress-tracking job ID
    allow_human_assistance: bool  # whether human-in-the-loop assistance is allowed
    method_code: str  # source code of the method under test
    method_signature: str  # signature of the method under test
    start_line: int  # first line of the method under test
    end_line: int  # last line of the method under test
    class_name: str  # name of the enclosing class
    full_method_name: str  # fully-qualified method name
    method_summary: str  # LLM-produced summary of what the method does
    test_points: List[Dict]  # generated list of test points
    test_cases: Annotated[List, operator.add]  # generated test cases (appended across parallel branches)
    temp_test_cases: List  # scratch list used between iterations
    final_test_case: str  # final generated test case
    test_result: str  # test execution result
    test_report: str  # test report
    # Conversation history (LLM-compatible; add_messages merges new entries)
    messages: Annotated[List[Union[dict, BaseMessage]], add_messages]


class SubState(TypedDict):
    """Per-test-point state handed to `testCaseGenerator` via `Send`.

    A subset of `PlannerState` plus the single `requirement` (test point)
    that the generator subgraph should turn into one test case.
    """
    envServer: EnvironmentService  # environment service handle
    allow_human_assistance: bool  # whether human-in-the-loop assistance is allowed
    job_id: Optional[str]  # progress-tracking job ID
    package_name: str  # package name of the code under test
    method_code: str  # source code of the method under test
    method_signature: str  # signature of the method under test
    full_method_name: str  # fully-qualified method name
    start_line: int  # first line of the method under test
    end_line: int  # last line of the method under test
    class_name: str  # name of the enclosing class
    method_summary: str  # LLM-produced summary of what the method does
    requirement: Dict  # the single test point to generate a test case for


def planner_init_state(state: PlannerState):
    """Seed the planner conversation and reset the test-case accumulators."""
    bootstrap = Planner_Init_State_Template.invoke({}).to_messages()
    return {
        "messages": bootstrap,
        "test_cases": [],
        "temp_test_cases": [],
    }


def methodAnalyzer(state: PlannerState):
    """Ask the LLM to summarize the method under test.

    Sends the analyzer prompt plus the conversation so far, then retries up
    to `max_try_times` with corrective feedback when the response does not
    contain a triple-backtick summary block. On success the summary is
    persisted via `graph_retriever.update_method`.

    Returns:
        dict with `messages` (the prompt/response pair to merge into state)
        and `method_summary` (str, or None when extraction kept failing —
        `checkSummary` treats None as "retry the analyzer").
    """
    def extract_summary(content: str):
        # The summary is expected inside the first ```...``` fence.
        match = re.search(r'```(.*?)```', content, re.DOTALL)
        return match.group(1) if match else None

    prompt = Method_Analyzer_Template.invoke({"method_code": state["method_code"]})
    valid_prompt = prompt.to_messages()
    # `query_prompt` aliases state["messages"]; extending it also records the
    # analyzer prompt in the running conversation.
    query_prompt = state["messages"]
    query_prompt.extend(valid_prompt)
    result = llm.invoke(query_prompt)
    state['messages'].append(result)

    method_summary = None
    max_try_times = 3
    for _ in range(max_try_times):
        method_summary = extract_summary(result.content)
        if method_summary:
            # Only the final, well-formed response is kept in valid_prompt.
            valid_prompt.append(result)
            break
        feedback = HumanMessage(
            "Can not find triple backticks in the response, please provide the method summary in the triple backticks."
        )
        state['messages'].append(feedback)
        result = llm.invoke(state['messages'])
        state["messages"].append(result)

    # Bug fix: the original unconditionally persisted the summary, writing
    # None/"" to the store when all retries failed. Only persist real summaries.
    if method_summary:
        graph_retriever.update_method(state["method_id"], method_summary)
    return {"messages": valid_prompt, "method_summary": method_summary}


def checkSummary(state: PlannerState):
    """Routing helper: re-run the analyzer when no summary was produced,
    otherwise proceed to test-point generation."""
    return "methodAnalyzer" if state["method_summary"] is None else "testPointsGenerator"


def testPointsGenerator(state: PlannerState):
    """Ask the LLM for test points (JSON) covering the method under test.

    On the first pass the init template is used; on later passes the review
    template is fed the test cases generated so far. Retries up to
    `max_try_times` with corrective feedback when the response lacks a
    parseable ```json ...``` block.

    Returns:
        dict with `messages` (prompt/response pair) and `test_points`
        (at most the first 3 points, or None when extraction kept failing —
        `sendTestPoints` routes None back to this node).
    """
    def extract_test_points(content: str):
        match = re.search(r'```json(.*?)```', content, re.DOTALL)
        if not match:
            return None
        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError:
            # Bug fix: malformed JSON used to propagate and crash the node;
            # returning None lets the feedback loop ask the LLM to correct it.
            return None

    if len(state["test_cases"]) == 0:
        prompt = Test_Points_Init_Template.invoke({})
    else:
        prompt = Test_Points_Review_Template.invoke({"test_cases_info": state["test_cases"]})
    valid_prompt = prompt.to_messages()

    # `query_prompt` aliases state["messages"]: the running conversation.
    query_prompt = state["messages"]
    query_prompt.extend(valid_prompt)
    result = llm.invoke(query_prompt)
    state['messages'].append(result)

    test_points = None
    max_try_times = 3
    for _ in range(max_try_times):
        test_points = extract_test_points(result.content)
        if test_points:
            valid_prompt.append(result)
            break
        feedback = HumanMessage(
            "Can not find json data in the response, please provide the test points in the triple backticks."
            "And make sure the format is following the JSON format."
        )
        query_prompt.append(feedback)
        result = llm.invoke(query_prompt)
        query_prompt.append(result)

    # Bug fix: the original did `test_points[:3]` unconditionally, raising
    # TypeError when all retries failed. Returning None is the value
    # `sendTestPoints` already handles (it loops back to this node).
    return {"messages": valid_prompt,
            "test_points": test_points[:3] if test_points is not None else None}


def testCaseGenerator(state: SubState):
    """Run the generator subgraph for a single test point and collect its report.

    Returns a one-element `test_cases` list so the operator.add reducer on
    PlannerState accumulates reports across the parallel Send branches.
    """
    generator_input = {
        "envServer": state["envServer"],
        "allow_human_assistance": state.get("allow_human_assistance", True),
        "package_name": state["package_name"],
        "method_signature": state["method_signature"],
        "method_code": state["method_code"],
        "requirement": state["requirement"],
        "start_line": state["start_line"],
        "end_line": state["end_line"],
        "class_name": state["class_name"],
        "full_method_name": state["full_method_name"],
        "test_case": "",
        "method_summary": state["method_summary"],
        "generate_or_evaluation": "generation",
        "job_id": state.get("job_id"),
    }
    response = generator_graph.invoke(generator_input)

    # report key -> response key ("test_point" is sourced from "requirement").
    field_map = {
        "test_result": "test_result",
        "find_bug": "find_bug",
        "bug_report": "bug_report",
        "test_point": "requirement",
        "test_case": "test_case",
        "test_report": "test_report",
        "coverage_report": "coverage_report",
        "mutation_report": "mutation_report",
    }
    report = {dest: response[src] for dest, src in field_map.items()}
    return {"test_cases": [report]}


def sendTestPoints(state: PlannerState):
    """Fan out one `testCaseGenerator` task per test point.

    When `test_points` is None (extraction failed upstream), routes back to
    `testPointsGenerator` instead of fanning out.
    """
    test_points = state["test_points"]
    if test_points is None:
        return "testPointsGenerator"

    # Fields shared by every branch; only `requirement` varies per point.
    shared_payload = {
        "envServer": state["envServer"],
        "allow_human_assistance": state["allow_human_assistance"],
        "job_id": state.get("job_id"),
        "package_name": state["package_name"],
        "method_signature": state["method_signature"],
        "method_code": state["method_code"],
        "class_name": state["class_name"],
        "start_line": state["start_line"],
        "end_line": state["end_line"],
        "full_method_name": state["full_method_name"],
        "method_summary": state["method_summary"],
    }
    return [
        Send("testCaseGenerator", {**shared_payload, "requirement": point})
        for point in test_points
    ]


## 顺序执行生成测试用例
# def sendTestPoints(state: CoordinatorState):
#     envServer = state["envServer"]
#     test_points = state["test_points"]
#     package_name = state["package_name"]
#     method_signature = state["method_signature"]
#     method_code = state["method_code"]
#     class_name = state["class_name"]
#     start_line = state["start_line"]
#     end_line = state["end_line"]
#     method_summary = state["method_summary"]
#     if test_points is None:
#         return "testPointsGenerator"
#     test_cases = []
#     for test_point in test_points:
#         response = executor_graph.invoke({"envServer": envServer,
#                                           "test_point": test_point,
#                                           "package_name": package_name,
#                                           "method_signature": method_signature,
#                                           "method_code": method_code,
#                                           "class_name": class_name,
#                                           "start_line": start_line,
#                                           "end_line": end_line,
#                                           "method_summary": method_summary})
#
#         test_report = {"test_result": response['test_result'],
#                        "find_bug": response['find_bug'],
#                        "test_point": response['test_point'],
#                        "test_case": response['test_case'],
#                        "test_report": response['test_report']}
#         test_cases.append(test_report)
#     return {"test_cases": test_cases}


def testCaseAcceptor(state: PlannerState):
    """Final acceptance node: pass the accumulated test cases through.

    The LLM-based acceptance review (Acceptance_Review_Template) is currently
    disabled; this node simply echoes `test_cases` plus the identifying
    metadata the caller needs. Removed the dead read of `temp_test_cases`
    and the commented-out review code from the original.
    """
    return {
        "test_cases": state["test_cases"],
        "package_name": state["package_name"],
        "method_id": state["method_id"],
        "method_signature": state["method_signature"],
        "start_line": state["start_line"],
        "end_line": state["end_line"],
    }


def acceptorCheck(state: PlannerState):
    """Conditional edge after acceptance: currently always terminates the graph.

    The original body guarded a loop back to "testPointsGenerator" behind
    `if False:` (unreachable) with an unused local — removed as dead code.
    This edge is also not wired into the graph (see the commented-out
    add_conditional_edges call below the builder).
    """
    # TODO: implement a real acceptance criterion and return
    # "testPointsGenerator" to regenerate; until then, always finish.
    return END


# Build the planner graph:
# initState -> methodAnalyzer -> testPointsGenerator -(Send fan-out)-> testCaseGenerator -> testCaseAcceptor -> END
graphBuilder = StateGraph(PlannerState)

graphBuilder.add_node("initState", planner_init_state)
# LLM nodes retry on ValueError (e.g. JSON/parse failures) up to 3 attempts.
graphBuilder.add_node("methodAnalyzer", methodAnalyzer, retry=RetryPolicy(max_attempts=3, retry_on=ValueError))
graphBuilder.add_node("testPointsGenerator", testPointsGenerator,
                      retry=RetryPolicy(max_attempts=3, retry_on=ValueError))
# graphBuilder.add_node("executorGraph", executor_graph)
graphBuilder.add_node("testCaseGenerator", testCaseGenerator)
# graphBuilder.add_node("sendTestPoints", sendTestPoints)
graphBuilder.add_node("testCaseAcceptor", testCaseAcceptor)
# graphBuilder.add_node("compileAndExecuteTest", compileAndExecuteTest)

graphBuilder.add_edge(START, "initState")
graphBuilder.add_edge("initState", "methodAnalyzer")
# graphBuilder.add_conditional_edges("methodAnalyzer", checkSummary,
#                                    {"methodAnalyzer": "methodAnalyzer",
#                                     "testPointsGenerator": "testPointsGenerator"})
# graphBuilder.add_conditional_edges("testPointsGenerator", checkTestPoints,
#                                    {"testPointsGenerator": "testPointsGenerator",
#                                     "executorGraph": "executorGraph"})
graphBuilder.add_edge("methodAnalyzer", "testPointsGenerator")
# graphBuilder.add_conditional_edges("testPointsGenerator", checkTestPoints,["executorGraph"])
# graphBuilder.add_edge("testPointsGenerator", "sendTestPoints")
# sendTestPoints returns Send objects -> one testCaseGenerator branch per test point.
graphBuilder.add_conditional_edges("testPointsGenerator", sendTestPoints, ["testCaseGenerator"])
graphBuilder.add_edge("testCaseGenerator", "testCaseAcceptor")
# graphBuilder.add_conditional_edges("testCaseAcceptor", acceptorCheck,
#                                    {"testPointsGenerator": "testPointsGenerator",
#                                     END: END})
graphBuilder.add_edge("testCaseAcceptor", END)

# Compiled planner graph, importable by callers of this module.
graph = graphBuilder.compile()

if __name__ == "__main__":
    # Render the compiled graph to a PNG for documentation/debugging.
    try:
        png_bytes = graph.get_graph(xray=1).draw_mermaid_png()
        with open("Planner.png", "wb") as out_file:
            out_file.write(png_bytes)
    except Exception as e:
        # Rendering requires optional tooling; failure is non-fatal.
        print(f"Failed to save image: {e}")
    print("Graph compiled successfully.")